{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use SARSA to Play Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Route INFO-and-above log records to stdout, each prefixed with an\n",
    "# HH:MM:SS timestamp and the level name.\n",
    "logging.basicConfig(level=logging.INFO,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] env: <TaxiEnv<Taxi-v3>>\n",
      "00:00:00 [INFO] action_space: Discrete(6)\n",
      "00:00:00 [INFO] observation_space: Discrete(500)\n",
      "00:00:00 [INFO] reward_range: (-inf, inf)\n",
      "00:00:00 [INFO] metadata: {'render.modes': ['human', 'ansi']}\n",
      "00:00:00 [INFO] _max_episode_steps: 200\n",
      "00:00:00 [INFO] _elapsed_steps: None\n",
      "00:00:00 [INFO] id: Taxi-v3\n",
      "00:00:00 [INFO] entry_point: gym.envs.toy_text:TaxiEnv\n",
      "00:00:00 [INFO] reward_threshold: 8\n",
      "00:00:00 [INFO] nondeterministic: False\n",
      "00:00:00 [INFO] max_episode_steps: 200\n",
      "00:00:00 [INFO] _kwargs: {}\n",
      "00:00:00 [INFO] _env_name: Taxi\n"
     ]
    }
   ],
   "source": [
    "# Create the Taxi-v3 environment and seed it with 0.\n",
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "\n",
    "# Log every instance attribute of the environment, then of its spec.\n",
    "for inspected in (env, env.spec):\n",
    "    for attr_name, attr_value in vars(inspected).items():\n",
    "        logging.info('%s: %s', attr_name, attr_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSAAgent:\n",
    "    \"\"\"Tabular SARSA agent with an epsilon-greedy behaviour policy.\n",
    "\n",
    "    Action values are stored in ``self.q``, a zero-initialised array with\n",
    "    one row per state and one column per action. Exploration and learning\n",
    "    only happen while the agent is in ``'train'`` mode.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.99, learning_rate=0.2, epsilon=0.01):\n",
    "        \"\"\"Set up the hyperparameters and the action-value table.\n",
    "\n",
    "        Args:\n",
    "            env: Object exposing ``observation_space.n`` and\n",
    "                ``action_space.n``.\n",
    "            gamma: Discount factor applied to the next state-action value.\n",
    "            learning_rate: Step size of each temporal-difference update.\n",
    "            epsilon: Probability of picking a uniformly random action\n",
    "                while training.\n",
    "        \"\"\"\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        # Defined up front so step() also works before the first reset().\n",
    "        self.mode = None\n",
    "        self.trajectory = []\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        \"\"\"Begin a new episode.\n",
    "\n",
    "        Args:\n",
    "            mode: ``'train'`` turns on exploration and learning and clears\n",
    "                the stored trajectory; any other value gives a purely\n",
    "                greedy agent that does not learn.\n",
    "        \"\"\"\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory = []\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        \"\"\"Pick the action for ``observation`` and, when training, learn.\n",
    "\n",
    "        Args:\n",
    "            observation: Current state index (row of ``self.q``).\n",
    "            reward: Reward received for the previous action.\n",
    "            done: Whether ``observation`` is terminal.\n",
    "\n",
    "        Returns:\n",
    "            Index of the chosen action.\n",
    "        \"\"\"\n",
    "        if self.mode == 'train' and np.random.uniform() < self.epsilon:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        else:\n",
    "            action = self.q[observation].argmax()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            # Two consecutive (observation, reward, done, action) records\n",
    "            # together hold one full (S, A, R, S', A') transition.\n",
    "            if len(self.trajectory) >= 8:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"End-of-episode hook; this agent has nothing to release.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        \"\"\"Apply one SARSA update using the last two trajectory records.\"\"\"\n",
    "        (state, _, _, action,\n",
    "         next_state, reward, done, next_action) = self.trajectory[-8:]\n",
    "\n",
    "        # (1. - done) removes the bootstrap term on terminal transitions.\n",
    "        target = reward + self.gamma * \\\n",
    "                self.q[next_state, next_action] * (1. - done)\n",
    "        td_error = target - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error\n",
    "\n",
    "\n",
    "# Single agent instance shared by the training and evaluation loops.\n",
    "agent = SARSAAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] ==== train ====\n",
      "00:00:25 [INFO] ==== test ====\n",
      "00:00:25 [INFO] average episode reward = 7.75 ± 2.41\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAenElEQVR4nO3de3hU5bn38e+dkAOnQDAcQgIkAqIBFCVSULEeUNBWQd27pV4WW9tS2fZgd3er1G61B3b72ret9a3a0ta63bYibbWw6xm1ai2IQcEAEogQDIRDOCSEQyCH+/1jFjjiJCFMkoGs3+e65mLN/ay15nkg/LLmmTVrmbsjIiLhkpToDoiISMdT+IuIhJDCX0QkhBT+IiIhpPAXEQmhLonuwLHKysryvLy8RHdDROSksmzZsh3u3vfo+kkT/nl5eRQVFSW6GyIiJxUz2xirrmkfEZEQUviLiISQwl9EJIQU/iIiIaTwFxEJIYW/iEgIKfxFRELopDnP/3i5Oy+XbOf3r5fx2rodANwwfjDVB+q56sxsuqV24UBdA8+u3Mqqimpuu+J06uobefzNctJSkhiU2Y3izdWcm9eHYf168JPnShid04unirfQu1sKN3xsCFUHDrG6Yg8XjehHdq90Nu7czwN/L+XWSafxxoadrNlSQ9+eaaSlJLOivIpnvj6RfQfr+fzDb3J5wQAa3RnYO51Gh2vOzmFFeRUH6xspKtvF5JED+Lc/vkWfbqn85ycLuPXx5R8Z45SRA9i9/xBLy3bhDoVDMpk+bjCPv/k+5w3N4hcvrmPuZ8fycsl2Hltazl1XFVC8uZon3trMeUNP4Xc3nktF9QH+XlLJK2srOXdIJvsONXDdOTnMLyqnaONuks3I6JrC4D7d+JexueT07sqflpXT6DBmUG/uWrCKkm013PvpMby4Zjv/u6KCH107mjfLdlG+az9vlu0GYPLI/nRP7cLYvEy2Vdcyfugp/OqV9by6tpLULkmcmtWdOz9ZwLD+PViyfhc90pK559kSLj69H/UNjZgZO/ceYszg3ozO6cXC5RXsqa3juZVbmXHeEN7bvo9X1lZSW9/AsL49WLd9L1eOHkBqchKHGhp5ungrAL+8/mwM4+WS7byxYSdV++rI79udXl1T2LhzPzv2HmT/oQb++6ZxjM7pRXpKEm+/X8X/e2kdI/r3ZMSADDZX7eeF1dsozOvDtDE5ZHZLYVPVAdK6JHH9b94gp3dXzsjuyeic3ix6dxsAxZurueqsgfx9zXZqDtZz5egBbK2uZUCvdG46P58e6V2Ycu9rAEwdM5AFyyuO/DuP6N+TS8/oxxNvbeZnnzqL/1mykZfWbCc1OYmvXjqMicP78vb7VVQdOMTvXtvApDP6A1C2cx9vbNjF587LY+yQTK4cnc3zq7ZStnM/tXUNVB+oA+Dp4i1srzlIdq908rO688/3dgLwX9eM5jtPFnPt2TnkZ3Xnpy+sJbNbCjedn8/i9TvplprMone3A3D+sFPI6d2V0bm92XOgjqeLt7CqYg8/nDaKNVv3sKWqlkF9urHo3W1ceno/rjprIKXb9/LEW5sp27mPlOQkNlcdYOyQTMbl92Hh8go2Vx0A4CsXD2PwKd2Y++p6tu2ppaa2nuvOyWXXvoMM6JXOY0vLj/xdXTFqAMvLqxgxoCfD+vagbOc+Lh85gLVba8jqmcaWqgP89+KNjOjfk+9PHUmXZOO6BxeT2iWJH04bRW1dAxeP6Mf8onI2Vx2gf0Y6148bzNxX13P+sCxufnTZh/4PfvOy06g+UMcFw7MY1KcbAGu31rB2215mTBhCgzvlu/bz61fWs2vfIbJ7p5NkxpNvb+Yz4wZzzuDe/PrV9QzISGdz1QG2VB/g5o8P5d5F67j7qgI+d37+8UZgk+xkuZ5/YWGhH8+XvB5/831u+0txO/RIRKRjbPjRlZjZcW1rZsvcvfDoesKmfcxsipmVmFmpmd3eXq8z
56l322vXIiIdoq6h7Q/SExL+ZpYM3A9cARQAnzGzgvZ4rT219e2xWxGRk1qijvzHAaXuvt7dDwHzgKkJ6ouIyAktJfn4pnyak6jwzwHKo55vCmofYmYzzazIzIoqKys7rHMiIieK/hlpxz3f35xEhX+skXxkUsvd57p7obsX9u37kSuSJkRuZtdEd+FDbptyerPtAzLSOWdw747pTAsGZKQzY8KQRHdD5KTy1Ncmtst+E3Wq5yZgUNTzXKCiiXUT5vtTR3LVmQOpb3T69kxj0+795GZGTuN6pngLs/7wFgN7pTNv5gRSuhj1Dc7vXy8jOQmuOTuXPt1TqWtopEuykZqcRNnOfZwzOJOni7cy7833ueuqAiprDvGZ3yw58prP3jqR0wdkUFlzkGUbdx85peylb36crXtqSU1O4t5F6/jBtFHkZ3UHYNZFQ6mta6Cmtp6uqcnU1Tdy+b2vUllzkAdvOIezB2cCkHf7Ux8aX9F3J9HQ6PTPSKdkaw0ryqsYmZPBz19Yx7Xn5PD+rv0UDslkeP+eADz0jw3cdEE+L767jbQuyZw/7BReWVvJ1+ctB+Brlw7nvhfXHdn/qJwM/nzzeazcXM3YIZlHjl5uPC+PHz29huxe6fzPksjVZj9/fh6/f70MgF9MH3Nkn+Py+7B0w64P9XvDj67kH6U7eG3dDmZ9fCjVB+p4u3w333h8BQDzvzyBvj3T6NM9FTPISE8BYHXFHrbV1DIosxu1dQ1k90rHzFiwfDPFm6q5ddJpLN9UxdVnDWTe0ve58LS+3P9yKSMG9KSiqpYvTswnq0fk5yA1OYnnV2/jhvGRX2a/fW096SnJXHx6P15ft4NfvlzKsH49OCO7J7dOOo2U5CTqGhrZufcQ727Zw+cffhOAd+6+nM27D9C7WwqzHn2L5eVV9M9I443vTOKfpTvo1S2F7F5dqWtoZMfeg1Tvr2NUbi/OvPt5AJ7/xoX8/vUNfPrcwayqqKa2rpHuqckU5mVysL6R3N7dOFDXQF1DIwtXVPCliafy7pY9lG7fS3KSMSqnF2ldkhjUp9uRn4+yH3+Cmto65i0t54sT89lSXUvZjn2YGT3Tu2AGKclJfOWPb/HoFz5G355pfPvP73BZQX+G9uvBV//4Nqu37Dmyr8P7/cG0UVw/bjBvvb+bN8t2kdUjjavPGkhdQyM/fX4tPdK68MuXS4/8bGb1SOPZlVs5rX8PLvnpKwDMuWYUE4f1ZcWmKnbtO8SEoaewcnM12/YcJKtHKlePGciWqlq6pSZTfaCOLslJ/KN0B//515UA3HLxUC4e0Y+e6Sn84Y2NfOGCfDK7p7Jp1wFWVVTzrT+/E/n3nFFIft/u/OyFtYwdnIkZFGRn8PA/y+jVNYWROb246sxsSrfv5VevrAecH193Jj3SurCnto5+PdNpaIyc3vntP7/D0rIPfoYvOb0fX77wVN7ZVM3lI/vz61fXU72/LnL6sRmPLtnIZQX92Xuwnt37DpHVI+0j2dQWEnKqp5l1AdYClwKbgTeB6919VVPbHO+pnkcH3rH66iXD+OblI5psL9law+R7X+W7nziDL0489bhe47CRdz7LvkMNQOQ/S7T1lXvZue8Q5+b1adU+Zz26jGdWbuWFb1x4JLxfWrONjTv3873/XU1WjzSKvjsprn4fNvnnr9Kneyq/umEsZ33/eb40MZ/fvLaB6ecO4sfXndnkdve/XMpPnith1kVDuW3K6RSV7SKzeypD+/Zg0eptfPGRIr788VP59SvrAfjuJ85gWL8eXDSiX8z9RYfXiczduee5Ej4xOptROb2O1F9bV8lnf7eU+68/h0+cmd3sPu57cR33v1xKyQ+vaLN+teXfX/X+OpKSoGd6Cnm3P8WQU7rxyrcubnG7afe/zvLyKlbcdTm9uqYcqd+9cBUvrN7G67dfclz9WV+5l0cWb+TOTxaQlBR7CmXe0ve5/YniFn9uW2v7nlrG/deLQOQ7OH+edV6b7ftYNHWq
Z8LO8zezK4F7gWTgIXef09z6HRn+/zo2lzuvKqBnekqz622uOsDA4OgxHg2Nztfnvc1XLxnOiAE949rXYTW1dby2bgdXjv5oiOzed4jULkl0T2u/N34LV1Rw2Rn96Zqa3OQ667bVcNnPX+Wpr13AyIG9PtTm7vx1+WauHJ3N9j0HKd+9n/OGZjX7mis3VwN8KFBPNpurDpDTOzFTi+31y/PlNdsZOTCDfhnpLa67a98hlm3czWUF/du0D8eitq6Buxas4rYrTqdP99Q23fcb63fy6blLODcvkz/dHPLwb632Dv8Jp57C4vWRbzSe6EeOIu3hZHnndDJasn4n0+cuYVxeH+bfPKFDX/uE+5LXiSZJfxMSclk92vZoVz4wIph6/eLEtr9Mw/Hq9Nf2OVYW8wQkkfB45VsXU9fQmOhudEqZ3VNPuHdUCv9AO5xGK3JSac/PgOTEo8mOwE3tcNU8EZETVah/1Y/KyeDas3O56YJI8D/2pfEJO9NCRKQjhTr8H/vS+A+dzjlh6CkJ7I2ISMcJ9bRPS+fxi4h0VqEOfxGRsFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCCn8RkRCKK/zN7F/NbJWZNZpZ4VFts82s1MxKzGxyVH2smRUHbfdZe9yWXkREmhXvkf9K4Frg1eiimRUA04GRwBTgATM7fD+/B4GZwPDgMSXOPoiISCvFFf7u/q67l8RomgrMc/eD7r4BKAXGmVk2kOHuiz1y/8hHgGnx9EFERFqvveb8c4DyqOebglpOsHx0PSYzm2lmRWZWVFlZ2S4dFREJoxYv6Wxmi4ABMZrucPcFTW0Wo+bN1GNy97nAXIjcwL2FroqIyDFqMfzdfdJx7HcTMCjqeS5QEdRzY9RFRKQDtde0z0JgupmlmVk+kQ92l7r7FqDGzMYHZ/nMAJp69yAiIu0k3lM9rzGzTcAE4Ckzew7A3VcB84HVwLPALe7eEGw2C/gtkQ+B3wOeiacPIiLSenHdxtHdnwSebKJtDjAnRr0IGBXP64qISHz0DV8RkRBS+IuIhJDCX0QkhBT+IiIhpPAXEQkhhb+ISAgp/EVEQkjhLyISQgp/EZEQUviLiISQwl9EJIQU/iIiIaTwFxEJIYW/iEgIKfxFREJI4S8iEkLx3snrJ2a2xszeMbMnzax3VNtsMys1sxIzmxxVH2tmxUHbfcHtHEVEpAPFe+T/AjDK3c8E1gKzAcysAJgOjASmAA+YWXKwzYPATCL39R0etIuISAeKK/zd/Xl3rw+eLgFyg+WpwDx3P+juG4jcr3ecmWUDGe6+2N0deASYFk8fRESk9dpyzv8mPrgZew5QHtW2KajlBMtH12Mys5lmVmRmRZWVlW3YVRGRcGvxBu5mtggYEKPpDndfEKxzB1AP/OHwZjHW92bqMbn7XGAuQGFhYZPriYhI67QY/u4+qbl2M7sR+CRwaTCVA5Ej+kFRq+UCFUE9N0ZdREQ6ULxn+0wBbgOudvf9UU0LgelmlmZm+UQ+2F3q7luAGjMbH5zlMwNYEE8fRESk9Vo88m/BL4E04IXgjM0l7n6zu68ys/nAaiLTQbe4e0OwzSzgYaArkc8InvnIXkVEpF3FFf7uPqyZtjnAnBj1ImBUPK8rIiLxCd03fIvvvjzRXRARSbjQhX/P9JREd0FEJOFCF/4iIqLwFxEJJYW/iEgIKfxFREIoVOGfnhKq4YqINCneL3mdlJ7/xoXUNTQmuhsiIgkTyvA/rX/PRHdBRCShQjUPkpvZLdFdEBE5IYQq/L939chEd0FE5IQQqvDPy+qe6C6IiJwQQhX+Ob27JroLIiInhFCFv4iIRCj8RURCSOEvIhJC8d7G8Qdm9o6ZLTez581sYFTbbDMrNbMSM5scVR9rZsVB233B7RxFRKQDxXvk/xN3P9PdxwB/A+4EMLMCYDowEpgCPGBmycE2DwIzidzXd3jQLiIiHSiu8Hf3PVFPuwMeLE8F
5rn7QXffAJQC48wsG8hw98Xu7sAjwLR4+iAiIq0X9+UdzGwOMAOoBi4OyjnAkqjVNgW1umD56HpT+55J5F0CgwcPjrerIiISaPHI38wWmdnKGI+pAO5+h7sPAv4AfOXwZjF25c3UY3L3ue5e6O6Fffv2bXk0IiJyTFo88nf3Sce4rz8CTwF3ETmiHxTVlgtUBPXcGHUREelA8Z7tMzzq6dXAmmB5ITDdzNLMLJ/IB7tL3X0LUGNm44OzfGYAC+Lpg4iItF68c/4/NrMRQCOwEbgZwN1Xmdl8YDVQD9zi7g3BNrOAh4GuwDPBQ0REOlBc4e/u1zXTNgeYE6NeBIyK53VFRCQ++oaviEgIKfxFREJI4S8iEkIKfxGREFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCCn8RkRBS+IuIhJDCX0QkhBT+IiIhpPAXEQmhNgl/M/sPM3Mzy4qqzTazUjMrMbPJUfWxZlYctN0X3M5RREQ6UNzhb2aDgMuA96NqBcB0YCQwBXjAzJKD5geBmUTu6zs8aBcRkQ7UFkf+Pwe+DXhUbSowz90PuvsGoBQYZ2bZQIa7L3Z3Bx4BprVBH0REpBXiCn8zuxrY7O4rjmrKAcqjnm8KajnB8tH1pvY/08yKzKyosrIynq6KiEiUFm/gbmaLgAExmu4AvgNcHmuzGDVvph6Tu88F5gIUFhY2uZ6IiLROi+Hv7pNi1c1sNJAPrAg+s80F3jKzcUSO6AdFrZ4LVAT13Bh1ERHpQMc97ePuxe7ez93z3D2PSLCf4+5bgYXAdDNLM7N8Ih/sLnX3LUCNmY0PzvKZASyIfxgiItIaLR75Hw93X2Vm84HVQD1wi7s3BM2zgIeBrsAzwUNERDpQm4V/cPQf/XwOMCfGekXAqLZ6XRERaT19w1dEJIQU/iIiIaTwFxEJIYW/iEgIKfxFREJI4S8iEkIKfxGREFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCCn8RkRBS+IuIhFBc4W9md5vZZjNbHjyujGqbbWalZlZiZpOj6mPNrDhouy+4naOIiHSgtjjy/7m7jwkeTwOYWQEwHRgJTAEeMLPkYP0HgZlE7us7PGgXEZEO1F7TPlOBee5+0N03AKXAODPLBjLcfbG7O/AIMK2d+iAiIk1oi/D/ipm9Y2YPmVlmUMsByqPW2RTUcoLlo+sxmdlMMysys6LKyso26KqIiMAxhL+ZLTKzlTEeU4lM4QwFxgBbgJ8e3izGrryZekzuPtfdC929sG/fvi11VUREjlGXllZw90nHsiMz+w3wt+DpJmBQVHMuUBHUc2PURUSkA8V7tk921NNrgJXB8kJgupmlmVk+kQ92l7r7FqDGzMYHZ/nMABbE0wcREWm9Fo/8W3CPmY0hMnVTBnwZwN1Xmdl8YDVQD9zi7g3BNrOAh4GuwDPBQ0REOlBc4e/un22mbQ4wJ0a9CBgVz+uKiEh89A1fEZEQUviLiISQwl9EJIQU/iIiIaTwFxEJIYW/iEgIKfxFREJI4S8iEkIKfxGREFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCcYe/mX3VzErMbJWZ3RNVn21mpUHb5Kj6WDMrDtruC27nKCIiHSiuO3mZ2cXAVOBMdz9oZv2CegEwHRgJDAQWmdlpwa0cHwRmAkuAp4Ep6FaOIiIdKt4j/1nAj939IIC7bw/qU4F57n7Q3TcApcC44IbvGe6+2N0deASYFmcfRESkleIN/9OAiWb2hpm9YmbnBvUcoDxqvU1BLSdYProek5nNNLMiMyuqrKyMs6siInJYi9M+ZrYIGBCj6Y5g+0xgPHAuMN/MTgVizeN7M/WY3H0uMBegsLCwyfVERKR1Wgx/d5/UVJuZzQKeCKZwlppZI5BF5Ih+UNSquUBFUM+NURcRkQ4U77TPX4FLAMzsNCAV2AEsBKabWZqZ
5QPDgaXuvgWoMbPxwVk+M4AFcfZBRERaKa6zfYCHgIfMbCVwCLgxeBewyszmA6uBeuCW4EwfiHxI/DDQlchZPjrTR0Skg8UV/u5+CLihibY5wJwY9SJgVDyvKyIi8dE3fEVEQkjhLyISQgp/EZEQUviLiISQwl9EJIQU/iIiIaTwFxEJIYW/iEgIKfxFREJI4S8iEkIKfxGREFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICMUV/mb2uJktDx5lZrY8qm22mZWaWYmZTY6qjzWz4qDtvuB2jiIi0oHivZPXpw8vm9lPgepguQCYDowEBgKLzOy04FaODwIzgSXA08AUdCtHEZEO1SbTPsHR+6eAx4LSVGCeux909w1AKTDOzLKBDHdfHNzr9xFgWlv0QUREjl1bzflPBLa5+7rgeQ5QHtW+KajlBMtH12Mys5lmVmRmRZWVlW3UVRERaXHax8wWAQNiNN3h7guC5c/wwVE/QKx5fG+mHpO7zwXmAhQWFja5noiItE6L4e/uk5prN7MuwLXA2KjyJmBQ1PNcoCKo58aoi4hIB2qLaZ9JwBp3j57OWQhMN7M0M8sHhgNL3X0LUGNm44PPCWYACz66SxERaU9xne0TmM6Hp3xw91VmNh9YDdQDtwRn+gDMAh4GuhI5y0dn+oiIdLC4w9/dP9dEfQ4wJ0a9CBgV7+uKiMjx0zd8RURCSOEvIhJCCn8RkRBS+IuIhFBowv9j+X0S3QURkRNGaML/8S9PSHQXREROGKEJfxER+YDCX0QkhBT+IiIhpPAXEQkhhb+ISAgp/EVEQkjhLyISQm1xSecT2m9mFNLougmYiEi0Th/+lxX0T3QXREROOJr2EREJobjC38zGmNkSM1tuZkVmNi6qbbaZlZpZiZlNjqqPNbPioO2+4HaOIiLSgeI98r8H+J67jwHuDJ5jZgVEbu84EpgCPGBmycE2DwIzidzXd3jQLiIiHSje8HcgI1juBVQEy1OBee5+0N03AKXAODPLBjLcfbG7O/AIMC3OPoiISCvF+4HvrcBzZvZ/ifwiOS+o5wBLotbbFNTqguWj6zGZ2Uwi7xIYPHhwnF0VEZHDWgx/M1sEDIjRdAdwKfANd/+LmX0K+B0wCYg1j+/N1GNy97nAXIDCwkKdryki0kZaDH93n9RUm5k9Anw9ePon4LfB8iZgUNSquUSmhDYFy0fXRUSkA8U7518BfDxYvgRYFywvBKabWZqZ5RP5YHepu28BasxsfHCWzwxgQZx9EBGRVop3zv9LwC/MrAtQSzA/7+6rzGw+sBqoB25x94Zgm1nAw0BX4Jng0aJly5btMLONx9nPLGDHcW57stFYOyeNtXPqiLEOiVU0D8GlD8ysyN0LE92PjqCxdk4aa+eUyLHqG74iIiGk8BcRCaGwhP/cRHegA2msnZPG2jklbKyhmPMXEZEPC8uRv4iIRFH4i4iEUKcOfzObElxSutTMbk90f46VmT1kZtvNbGVUrY+ZvWBm64I/M6PaWnX57ODLd48H9TfMLK9DBxjFzAaZ2ctm9q6ZrTKzrwf1TjVeM0s3s6VmtiIY5/c64zijmVmymb1tZn8LnnfKsZpZWdDH5WZWFNRO/LG6e6d8AMnAe8CpQCqwAihIdL+Ose8XAucAK6Nq9wC3B8u3A/8nWC4IxpYG5AdjTg7algITiFxT6RngiqD+b8CvguXpwOMJHGs2cE6w3BNYG4ypU4036FOPYDkFeAMY39nGedSY/x34I/C3Tv4zXAZkHVU74ceasB+MDvgHmQA8F/V8NjA70f1qRf/z+HD4lwDZwXI2UBJrXMBzwdizgTVR9c8Av45eJ1juQuQbhpboMQf9WQBc1pnHC3QD3gI+1lnHSeS6XS8SuezL4fDvrGMt46Phf8KPtTNP++QA5VHPm7189Emgv0eujUTwZ7+g3tQ4c2j68tlHtnH3eqAaOKXden6MgrezZxM5Ku504w2mQZYD24EX3L1TjjNwL/Bt
oDGq1lnH6sDzZrbMIpehh5NgrJ35Bu6tunz0Sex4Lp99wv3dmFkP4C/Are6+x5q+u+dJO16PXN9qjJn1Bp40s1HNrH7SjtPMPglsd/dlZnbRsWwSo3ZSjDVwvrtXmFk/4AUzW9PMuifMWDvzkX9Tl5U+WW2zyJ3QCP7cHtSP5/LZR7axyEX5egG72q3nLTCzFCLB/wd3fyIod9rxunsV8HcitzDtjOM8H7jazMqAecAlZvYonXOsuHtF8Od24ElgHCfBWDtz+L8JDDezfDNLJfJBycIE9ykeC4Ebg+Ub+eBS2Mdz+ezoff0L8JIHE4odLejb74B33f1nUU2darxm1jc44sfMuhK56dEaOtk4Adx9trvnunsekf93L7n7DXTCsZpZdzPreXgZuBxYyckw1kR8QNKBH8RcSeTskfeAOxLdn1b0+zFgCx/c9vILROb4XiRyz4QXgT5R698RjLGE4AyBoF4Y/CC+B/ySD77RnU7k5julRM4wODWBY72AyFvYd4DlwePKzjZe4Ezg7WCcK4E7g3qnGmeMcV/EBx/4drqxEjmbcEXwWHU4Z06GseryDiIiIdSZp31ERKQJCn8RkRBS+IuIhJDCX0QkhBT+IiIhpPAXEQkhhb+ISAj9f5ZS4tWjN7wJAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    \"\"\"Run one episode of ``agent`` acting in ``env``.\n",
    "\n",
    "    The agent is handed every (observation, reward, done) triple, including\n",
    "    the terminal one, before the loop exits.\n",
    "\n",
    "    Args:\n",
    "        env: Environment whose ``reset()`` returns an observation and whose\n",
    "            ``step(action)`` returns ``(observation, reward, done, info)``.\n",
    "        agent: Object providing ``reset(mode=...)``, ``step(observation,\n",
    "            reward, done)`` and ``close()``.\n",
    "        max_episode_steps: Optional cap on environment steps; a falsy value\n",
    "            means no cap.\n",
    "        mode: Forwarded unchanged to ``agent.reset``.\n",
    "        render: When true, ``env.render()`` is called on every loop pass.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(episode_reward, elapsed_steps)``: the summed reward and the\n",
    "        number of environment steps taken.\n",
    "    \"\"\"\n",
    "    observation = env.reset()\n",
    "    reward, done = 0., False\n",
    "    agent.reset(mode=mode)\n",
    "    total_reward = 0.\n",
    "    step_count = 0\n",
    "    out_of_steps = False\n",
    "    while not out_of_steps:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if done:\n",
    "            # The agent has already seen the terminal transition above.\n",
    "            break\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "        step_count += 1\n",
    "        out_of_steps = bool(max_episode_steps\n",
    "                            and step_count >= max_episode_steps)\n",
    "    agent.close()\n",
    "    return total_reward, step_count\n",
    "\n",
    "\n",
    "# Train until the mean reward over the latest 200 episodes exceeds 8.3.\n",
    "# The unwrapped environment is stepped directly, so the step cap is\n",
    "# enforced by play_episode; it is read from the public env.spec rather\n",
    "# than the wrapper's private _max_episode_steps attribute.\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "for episode in itertools.count():\n",
    "    episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,\n",
    "            max_episode_steps=env.spec.max_episode_steps, mode='train')\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    if np.mean(episode_rewards[-200:]) > 8.3:\n",
    "        break\n",
    "\n",
    "# Learning curve, labelled so the figure can be read on its own.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(episode_rewards)\n",
    "ax.set(xlabel='episode', ylabel='episode reward',\n",
    "       title='SARSA training on Taxi-v3')\n",
    "\n",
    "\n",
    "# Evaluation: 100 episodes with the default mode, in which the agent\n",
    "# neither explores nor learns.\n",
    "logging.info('==== test ====')\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    episode_rewards.append(episode_reward)\n",
    "\n",
    "reward_mean = np.mean(episode_rewards)\n",
    "reward_std = np.std(episode_rewards)\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        reward_mean, reward_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the environment now that training and evaluation are done.\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
